from __future__ import annotations

import argparse
import math

import utils

from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy

from attacker import Attacker
from client import Client
from defender import Defender

def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options and fill unset ones from the dataset config.

    Options that are still ``None`` after parsing (either not given on the
    command line and declared without a default, or not declared at all) are
    populated from the module ``configs.<dataset>_config``. Options that carry
    an argparse default are never overridden by the config module.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which matches the previous
            zero-argument behaviour.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If the dataset has no matching config module. This is
            unreachable through the CLI because ``--dataset`` is restricted
            by ``choices``; it guards against the two lists drifting apart.
    """
    parser = argparse.ArgumentParser()

    # data
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'femnist', 'imagenet12'], required=True)
    parser.add_argument('--data_root', type=str, required=True)
    parser.add_argument('--n_classes', type=int)
    # data partition
    parser.add_argument('--n_clients', type=int)
    parser.add_argument('--dirichlet_beta', default=0.5, type=float)
    parser.add_argument('--dirichlet_min_n_data', default=10, type=int)

    # attack
    parser.add_argument('--byz_client_ratio', type=float, required=True)
    parser.add_argument('--skew_lambda', default=1, type=float)

    # server
    parser.add_argument('--architecture', type=str, choices=['alexnet', 'resnet18', 'squeezenet'])
    parser.add_argument('--server_n_epochs', type=int)
    # aggregator
    parser.add_argument('--bucket_s', default=1, type=int)
    parser.add_argument('--aggregator', type=str, choices=['aksel', 'cc', 'dnc', 'krum', 'median', 'rfa', 'trmean'], required=True)
    parser.add_argument('--cc_l', default=1, type=int)
    parser.add_argument('--cc_tau', default=10, type=float)
    parser.add_argument('--dnc_n_iters', default=1, type=int)
    parser.add_argument('--dnc_b', default=1000, type=int, help='dimension of subsamples')
    parser.add_argument('--dnc_c', default=1.0, type=float, help='filtering fraction')
    parser.add_argument('--rfa_budget', default=8, type=int)
    parser.add_argument('--rfa_eps', default=1e-7, type=float)

    # client
    parser.add_argument('--client_batch_size', type=int)
    parser.add_argument('--client_lr', type=float)
    parser.add_argument('--client_momentum', type=float)
    parser.add_argument('--client_weight_decay', type=float)
    parser.add_argument('--client_n_epochs', type=int)
    parser.add_argument('--clip_max_norm', default=None, type=float)

    # random
    parser.add_argument('--seed', type=int, default=0)

    # efficiency
    parser.add_argument('--pin_memory', action='store_true', default=False)
    parser.add_argument('--n_workers', default=1, type=int)
    parser.add_argument('--te_batch_size', default=128, type=int)

    args = parser.parse_args(argv)

    # set default config
    if args.dataset == 'femnist':
        from configs import femnist_config as config
    elif args.dataset == 'cifar10':
        from configs import cifar10_config as config
    elif args.dataset == 'imagenet12':
        from configs import imagenet12_config as config
    else:
        raise ValueError('invalid dataset')

    for name, val in vars(config).items():
        # A module's namespace also holds interpreter metadata (__builtins__,
        # __file__, __spec__, ...). Copying those onto args would pollute the
        # namespace and make printing args dump the entire builtins dict.
        if name.startswith('__') and name.endswith('__'):
            continue
        if getattr(args, name, None) is None:
            setattr(args, name, val)

    return args
def update_args(args, n_clients: int):
    """Record the actual client count and derive the Byzantine client count.

    Mutates ``args`` in place, setting ``n_clients`` and ``n_byz_clients``.

    Args:
        args: Namespace-like object exposing ``byz_client_ratio``.
        n_clients: Number of clients that actually exist after data loading.

    Returns:
        The same ``args`` object, for convenience.
    """
    args.n_clients = n_clients
    # compute number of byzantine clients
    # The float product can land a hair above an exact integer
    # (100 * 0.07 == 7.000000000000001), which a bare ceil would turn into one
    # extra Byzantine client. Round off that noise before taking the ceiling.
    expected_byz = round(args.n_clients * args.byz_client_ratio, 9)
    args.n_byz_clients = math.ceil(expected_byz)

    return args

def run_fl(args):
    """Run the federated-learning simulation described by ``args``.

    Every server epoch: broadcast the global model state, let the benign
    clients compute local updates, let the attacker craft the Byzantine
    messages, aggregate all messages through the defender, apply the
    aggregated update to the global model, and print test statistics.

    Args:
        args: Namespace produced by ``parse_args``. It is mutated via
            ``update_args`` so that ``n_clients`` and ``n_byz_clients``
            reflect the data that was actually loaded.
    """
    utils.setup_seed(args.seed)
    # load data
    client_datas, te_data = utils.load_data(args.dataset, args.data_root, args.n_clients, args.dirichlet_beta, args.dirichlet_min_n_data)
    n_clients = len(client_datas)
    update_args(args, n_clients)
    # record arguments
    print(f'args: {args}')
    # initialize client-side; clients are keyed by their index as a string
    clients = {str(idx): Client(client_datas[idx], args) for idx in range(n_clients)}
    client_idxs = {str(idx) for idx in range(n_clients)}
    # the first n_byz_clients indices are the Byzantine ones
    byz_client_idxs = {str(idx) for idx in range(args.n_byz_clients)}
    ben_client_idxs = client_idxs.difference(byz_client_idxs)
    # Iterating a set of strings follows per-process hash randomisation, so
    # the order would differ between runs even with a fixed seed. Fix a
    # numeric order once so benign updates are always computed (and inserted
    # into the message dict) in the same sequence.
    ben_client_order = sorted(ben_client_idxs, key=int)
    # initialize attacker
    byz_clients = {idx: clients[idx] for idx in byz_client_idxs}
    attacker = Attacker(byz_clients, args=args)
    # initialize server-side
    global_model = utils.get_model(args.architecture, args.n_classes)
    defender = Defender(args, byz_client_idxs)
    # test
    te_dataloader = DataLoader(dataset=te_data, batch_size=args.te_batch_size, shuffle=False, num_workers=args.n_workers, pin_memory=args.pin_memory)
    loss_fun = cross_entropy

    # start
    for epoch in range(args.server_n_epochs):
        # no client sampling: every client participates in every epoch
        # distribute
        server_message = {'model_state': global_model.state_dict()}
        # benign clients perform local update
        ben_client_msgs = {idx: clients[idx].local_update(server_message) for idx in ben_client_order}
        # attack
        defender_knowledge = {'n_byz_updates': len(byz_client_idxs)}
        attacker_knowledge = {
            'benign_client_messages': ben_client_msgs,
            'defender': defender,
            'defender_knowledge': defender_knowledge,
        }
        byz_client_msgs, byz_verbose_log = attacker.attack(byz_client_idxs, server_message, attacker_knowledge)
        # merge benign and Byzantine messages
        sampled_client_msgs = {**ben_client_msgs, **byz_client_msgs}
        # aggregate
        agg_update = defender.defend(sampled_client_msgs, defender_knowledge)
        # step
        utils.step(agg_update, global_model)
        # evaluate
        perf_stats = utils.eval_perf(global_model, te_dataloader, loss_fun)
        # output
        print(f'==================== epoch {epoch} ====================')
        utils.print_dict(byz_verbose_log, 'skew')
        utils.print_dict(perf_stats, f'performance {epoch}')

    print('==================== end of training ====================')

# Script entry point: runs only when the file is executed directly, not on import.
if __name__ == '__main__':
    # parse command-line arguments (missing values are filled from the dataset config)
    args = parse_args()
    # run the federated-learning training loop
    run_fl(args)